Build your own tensor type

This tutorial presents a core feature of the PySyft library: the ability to add custom tensor types that provide specific behaviors such as encryption or traceability. This feature keeps the library generic and open to new innovations in the field of privacy-preserving machine learning.

We will go through a very simple example which could be the base for a traceability feature that would keep track of the operations performed on the data in a verifiable way. This new tensor type will log all operations executed on tensors of its type. Let's call this type the CustomLoggingTensor.


0. Preliminaries

We use the sandbox that we have already discovered.

In [1]:
import torch as th
import syft as sy
sy.create_sandbox(globals(), verbose=False)

Setting up Sandbox...

Let's first recall the notions of Torch and Syft tensors. All the objects the end user manipulates are torch tensors. This is of course the case for a pure torch tensor (e.g. x = th.tensor([1., 2])), but it also holds when you deal with Syft objects such as the pointer tensor, which is a particular case of Syft tensor.

In [2]:
ptr = th.tensor([1., 2]).send(bob)
ptr

(Wrapper)>[PointerTensor | me:49975446348 -> bob:27355730204]

The wrapper object you see is actually an empty torch tensor with a child attribute which is a PointerTensor:

In [3]:
isinstance(ptr, th.Tensor)


In [4]:


This is also true for more complex objects, where you also see this wrapper at the beginning. You can then have multiple Syft or Torch tensors chained through the .child attribute.

In [5]:
x = th.tensor([1., 2]).fix_prec().share(alice, bob)

	-> [PointerTensor | me:97109067924 -> alice:10104773986]
	-> [PointerTensor | me:37015083947 -> bob:93808751958]
	*crypto provider: me*

In [6]:

	-> [PointerTensor | me:97109067924 -> alice:10104773986]
	-> [PointerTensor | me:37015083947 -> bob:93808751958]
	*crypto provider: me*

In [7]:

{'alice': [PointerTensor | me:97109067924 -> alice:10104773986],
 'bob': [PointerTensor | me:37015083947 -> bob:93808751958]}

Recall that the general behaviour is the following: each time a command is called on the top object, it goes down the chain, where it can be modified. It is then executed at the bottom, and the result is wrapped back so that it has exactly the same chain structure and keeps the same properties (such as traceability, for example).
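
To make this flow concrete, here is a minimal pure-Python sketch of such a chain. It does not use PySyft: the Node, Wrapper and Logging classes and the call method are hypothetical names invented for this illustration, and the real library handles many more cases.

```python
class Node:
    """Hypothetical chain node: forwards a call to its child and rewraps the result."""

    def __init__(self, child):
        self.child = child

    def call(self, method, *args):
        # Unwrap arguments of the same node type so the child works at its own level.
        args = [a.child if isinstance(a, type(self)) else a for a in args]
        if isinstance(self.child, Node):
            # Not at the bottom yet: keep going down the chain.
            result = self.child.call(method, *args)
        else:
            # Bottom of the chain: run the real operation on the raw value.
            result = getattr(self.child, method)(*args)
        # Rebuild the same node type around the result.
        return type(self)(result)

    def __repr__(self):
        return f"{type(self).__name__}>{self.child!r}"


class Wrapper(Node):
    pass


class Logging(Node):
    pass


x = Wrapper(Logging(3.0))
y = x.call("__mul__", 2)
print(y)  # Wrapper>Logging>6.0
```

The result has the same Wrapper>Logging>value structure as the input, which is the property the chain is meant to preserve.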

What we're going to do here is to create our own syft Tensor type that we will be able to put in this chain!

1. The MVP of the CustomLoggingTensor

1.1 Declare the class type

To get started, there isn't much to do. First, we need to create the tensor class.

This is done in syft/frameworks/torch/tensors/. Choose the folder:

  • interpreters if the functionality you want to build will modify the results or functions, or
  • decorators if the functionality is just ... decorative.

The interpreters / decorators might be removed in the future, in which case just put your tensor in syft/frameworks/torch/tensors/

Here we'll put it in the decorator folder. Choose a simple but explicit name, for now decorators/ will be sufficient.

Write there the minimal class definition, where our tensor inherits from AbstractTensor, an abstraction which gives default behaviours to Syft tensors:

In [8]:
from syft.generic.abstract.tensor import AbstractTensor

class CustomLoggingTensor(AbstractTensor):
    def __init__(self, **kwargs):
        # Let AbstractTensor handle the common arguments (id, owner, tags, ...)
        super().__init__(**kwargs)

This was quite fast, wasn't it?

1.2 Add to the hooks

1.2.1 Allow imports

You now need to declare this type in the imports so that you can actually use it. Add it in the files:

- syft/frameworks/torch/tensors/[decorators|interpreters]/
- syft/

You should now be able to import the tensor type: from syft import CustomLoggingTensor

1.2.2 Hook tensor to add torch operations

In the syft/frameworks/torch/hook/ file:

  • Add an import line at the top
  • Add the following in the TorchHook __init__: self._hook_syft_tensor_methods(CustomLoggingTensor)

All instances of CustomLoggingTensor now have for example a .add(...) method. We should now explain how to use it.

1.2.3 Hook tensor to correctly forward arguments to the torch operations

In particular, we would like arguments provided as CustomLoggingTensor to be unwrapped and replaced with their .child attribute, so that the call goes down the chain.

In the syft/generic/frameworks/hook/ file:

  • Add an import line at the top
  • Extend the type_rule dict with CustomLoggingTensor: one, (means that this type of tensors supports (un)wrapping)
  • Extend the forward_func dict with CustomLoggingTensor: lambda i: i.child, (explains how to unwrap)
  • Extend the backward_func dict with CustomLoggingTensor: lambda i, **kwargs: CustomLoggingTensor(**kwargs).on(i, wrap=False), (explains how to wrap)
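
As a rough picture of the role these dictionaries play, the sketch below uses plain Python dictionaries keyed by type. It is only an illustration under simplifying assumptions: ToyLoggingTensor, unwrap_args and wrap_response are hypothetical names, and the actual hook code in PySyft is more general (nested arguments, several tensor types, and so on).

```python
class ToyLoggingTensor:
    """Hypothetical stand-in for a Syft tensor type that holds a .child."""

    def __init__(self, child=None):
        self.child = child


# Toy registries keyed by type, mirroring the roles described above.
forward_func = {ToyLoggingTensor: lambda i: i.child}               # how to unwrap
backward_func = {ToyLoggingTensor: lambda i: ToyLoggingTensor(i)}  # how to wrap


def unwrap_args(args):
    """Replace every registered tensor by its child; leave other values untouched."""
    return tuple(
        forward_func[type(a)](a) if type(a) in forward_func else a for a in args
    )


def wrap_response(response, wrap_type):
    """Put a node of type wrap_type back on top of a raw response."""
    return backward_func[wrap_type](response)


args = (ToyLoggingTensor(5), 2)
raw = unwrap_args(args)  # (5, 2)
result = wrap_response(raw[0] + raw[1], ToyLoggingTensor)
print(type(result).__name__, result.child)  # ToyLoggingTensor 7
```

Registering a new type then amounts to adding one entry to each dictionary, which is what the steps above do for CustomLoggingTensor.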

Et voilà! You can already do many things with your new tensor!

In [9]:
x = CustomLoggingTensor()


OK, this is not super useful yet, but it comes with a .on method which works as follows:

In [10]:
x = th.tensor([1., 2])
x = CustomLoggingTensor().on(x)

(Wrapper)>CustomLoggingTensor>tensor([1., 2.])

.on simply inserts the tensor node into a tensor chain. As we always need to have a torch tensor at the top of the chain, a wrapper was automatically added.
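
The idea behind .on can be sketched in a few lines of plain Python. This is a simplified illustration, not the library code: ToyWrapper, ToyNode and describe are hypothetical names, and the wrap flag only mimics the wrap=False argument used in the backward_func entry above.

```python
class ToyWrapper:
    """Hypothetical stand-in for the empty torch tensor at the top of a chain."""

    def __init__(self, child):
        self.child = child


class ToyNode:
    """Hypothetical stand-in for a Syft tensor type."""

    def __init__(self):
        self.child = None

    def on(self, tensor, wrap=True):
        # Insert this node above the given tensor...
        self.child = tensor
        # ...and optionally add a wrapper so the top of the chain is a "torch tensor".
        return ToyWrapper(self) if wrap else self


def describe(obj):
    """Render a chain as Type>Type>value by following .child attributes."""
    parts = []
    while hasattr(obj, "child"):
        parts.append(type(obj).__name__)
        obj = obj.child
    parts.append(repr(obj))
    return ">".join(parts)


print(describe(ToyNode().on([1.0, 2.0])))              # ToyWrapper>ToyNode>[1.0, 2.0]
print(describe(ToyNode().on([1.0, 2.0], wrap=False)))  # ToyNode>[1.0, 2.0]
```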

Ready to use!

At this point, if you want to get the desired behaviour, you should make the code changes in the repository: integrating the code in the repository lets you benefit from the hooking functionalities. In particular, after re-instantiating the hook, your CustomLoggingTensor should have the methods a pure torch tensor has.

Make the change and re-run the notebook up to here

This time, we add the sy. prefix, meaning the class now comes from the repository code.

In [11]:
x = th.tensor([1., 2])
x = sy.CustomLoggingTensor().on(x)

(Wrapper)>CustomLoggingTensor>tensor([1., 2.])

You can do computations on this chain, such as x * 2. The resulting __mul__ call is forwarded all the way down the chain to the last node, which is a pure torch tensor, whose value is doubled.

In [12]:
x * 2

(Wrapper)>CustomLoggingTensor>tensor([2., 4.])

If you correctly obtained (Wrapper)>CustomLoggingTensor>tensor([2., 4.]), you're all set!

2. Adding special functionalities

Now that you have defined your own tensor type, you should specify its behaviour, as by default it won't do anything special and will just act passively.

In this part, we will see how to specify custom functionalities. For the execution parts we'll use the already existing LoggingTensor instead of the CustomLoggingTensor, and highlight which part of the code produces which functionality, so that you can run code in this notebook without reloading the kernel. If you want to practice more, you can copy the code changes into the CustomLoggingTensor class definition and you'll observe the same behaviours (just reload the notebook each time you modify the library code).

In [13]:
from syft import LoggingTensor

2.1 Default behaviour for functions

You can add a special functionality each time a (hooked) torch function is called on LoggingTensor: here we just print the call.

Note that this applies to functions only, not to methods, but it applies to all hooked torch functions.

In [14]:
class CustomLoggingTensor(AbstractTensor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @classmethod
    def on_function_call(cls, command):
        """
        Override this to perform a specific action for each call of a torch
        function with arguments containing syft tensors of the class doing
        the overloading
        """
        cmd, _, args, kwargs = command
        print("Default log", cmd)

In [15]:
x = th.tensor([1., 2])
x = LoggingTensor().on(x)

th.div(x, x)
th.nn.functional.celu(x) # celu is a variant of the activation function relu(x) = max(0, x)

Default log torch.div
Default log torch.nn.functional.celu
(Wrapper)>LoggingTensor>tensor([1., 2.])

Note: this on_function_call is called by handle_func_command which comes from the AbstractTensor: it explains how to propagate functions down the chain, and in some cases you might also need to change it.
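
To illustrate how these two pieces could fit together, here is a simplified pure-Python sketch. It is an assumption-laden toy, not the actual AbstractTensor code: ToyTensor and the FUNCTIONS table are hypothetical, and the math module stands in for torch. Only the shape of the command tuple (cmd, _, args, kwargs) is taken from the cell above.

```python
import math

# Hypothetical table of "hooked" functions; math stands in for torch here.
FUNCTIONS = {"math.sqrt": math.sqrt, "math.pow": math.pow}


class ToyTensor:
    calls = []  # names of the functions seen by on_function_call

    def __init__(self, child):
        self.child = child

    @classmethod
    def on_function_call(cls, command):
        cmd, _, args, kwargs = command
        cls.calls.append(cmd)
        print("Default log", cmd)

    @classmethod
    def handle_func_command(cls, command):
        # 1. Let the class react to the call (here: log it).
        cls.on_function_call(command)
        cmd, _, args, kwargs = command
        # 2. Unwrap the arguments that belong to this class.
        new_args = [a.child if isinstance(a, cls) else a for a in args]
        # 3. Run the function one level down the chain.
        result = FUNCTIONS[cmd](*new_args, **kwargs)
        # 4. Wrap the result back in the same class.
        return cls(result)


out = ToyTensor.handle_func_command(("math.sqrt", None, (ToyTensor(9.0),), {}))
print(out.child)  # 3.0
```

Overriding only on_function_call changes what happens in step 1, while overriding handle_func_command changes how the call is propagated.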

2.2 Overloading torch methods

We introduce here an important decorator object which is @overloaded:

In [16]:
from syft.generic.frameworks.overload import overloaded

You can directly override torch methods like this. Here we overload the .add method so that it first prints that it was called and then forwards the call to the .child attribute.

In [17]:
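class CustomLoggingTensor(AbstractTensor):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    @overloaded.method
    def add(self, _self, *args, **kwargs):
        # _self and the arguments have already been replaced by their .child
        print("Log method add")
        response = _self.add(*args, **kwargs)
        return response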

This is an example of how to use the @overloaded.method decorator. To see what this decorator does, just look at the manual_add method further down: it does exactly the same thing but without the decorator.

Note the subtlety between self and _self: you should use _self and NOT self. We kept self because it can hold useful attributes that you might want to access (for example, for the fixed precision tensor it stores the field size)
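
As an intuition for what such a decorator does, here is a hedged pure-Python sketch. toy_overloaded_method and ToyLoggingTensor are hypothetical names written for this example; the real @overloaded.method relies on the hook argument machinery and covers more cases.

```python
import functools


def toy_overloaded_method(method):
    """Hypothetical decorator: unwrap self and args, call method, rewrap the response."""

    @functools.wraps(method)
    def hooked(self, *args, **kwargs):
        cls = type(self)

        def unwrap(value):
            return value.child if isinstance(value, cls) else value

        _self = unwrap(self)
        new_args = [unwrap(a) for a in args]
        new_kwargs = {k: unwrap(v) for k, v in kwargs.items()}
        # The method receives both self (for its attributes) and _self (the child).
        response = method(self, _self, *new_args, **new_kwargs)
        # Put a node of the same class back on top of the response.
        return cls(response)

    return hooked


class ToyLoggingTensor:
    def __init__(self, child):
        self.child = child

    @toy_overloaded_method
    def add(self, _self, other):
        print("Log method add")
        return _self + other


x = ToyLoggingTensor(1.5)
r = x.add(x)
print(type(r).__name__, r.child)  # ToyLoggingTensor 3.0
```

Inside add, _self and other are the raw children (floats here), which is why the body must use _self and not self.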

Here is the version of the add method without the decorator: as you can see, it is much more complicated. However, you might sometimes need it to specify some particular behaviour, so this is where to start from if needed!

In [18]:
class CustomLoggingTensor(AbstractTensor):
    # [...]
    def manual_add(self, *args, **kwargs):
        # Replace all syft tensor with their child attribute
        new_self, new_args, new_kwargs = syft.generic.frameworks.hook.hook_args.hook_method_args(
            "add", self, args, kwargs
        )

        print("Log method manual_add")
        # Send it to the appropriate class and get the response
        response = getattr(new_self, "add")(*new_args, **new_kwargs)

        # Put back SyftTensor on the tensors found in the response
        response = syft.generic.frameworks.hook.hook_args.hook_response(
            "add", response, wrap_type=type(self)
        )
        return response

Both methods behave exactly the same and print a line when called.

In [19]:
x = LoggingTensor().on(th.tensor([1., 2]))

r = x.add(x)

(Wrapper)>LoggingTensor>tensor([1., 2.])
Log method add

You might want to try running r = x.manual_add(x), but this will fail: although the LoggingTensor stored in x.child has a .manual_add method, the wrapper x does not, because torch tensors don't have .manual_add by default.

2.3 Overloading torch functions

We will still use the @overloaded decorator but now with:

- @overloaded.module
- @overloaded.function

What we want to do is to overload

- torch.add
- torch.nn.functional.relu

In [20]:
class CustomLoggingTensor(AbstractTensor):
    # [...]
    @staticmethod
    @overloaded.module
    def torch(module):
        """
        We use the @overloaded.module to specify we're writing here
        a function which should overload the function with the same
        name in the <torch> module
        :param module: object which stores the overloading functions

        Note that we used the @staticmethod decorator as we're in a
        class
        """

        def add(x, y):
            """
            You can write the function to overload in the most natural
            way, so this will be called whenever you call torch.add on
            Logging Tensors, and the x and y you get are also Logging
            Tensors, so compared to the @overloaded.method, you see
            that the @overloaded.module does not hook the arguments.
            """
            print("Log function torch.add")
            return x + y

        # Just register it using the module variable
        module.add = add

        @overloaded.function
        def mul(x, y):
            """
            You can also add the @overloaded.function decorator to also
            hook arguments, ie all the LoggingTensor are replaced with
            their child attribute
            """
            print("Log function torch.mul")
            return x * y

        # Just register it using the module variable
        module.mul = mul

        # You can also overload functions in submodules!
        @overloaded.module
        def nn(module):
            """
            The syntax is the same, so @overloaded.module handles recursion
            Note that we don't need to add the @staticmethod decorator
            """

            @overloaded.module
            def functional(module):
                @overloaded.function
                def relu(x):
                    print("Log function torch.nn.functional.relu")
                    return x * (x > 0).float()

                module.relu = relu

            module.functional = functional

        # Modules should be registered just like functions
        module.nn = nn

Note the difference between def add and def mul: def add doesn't have @overloaded.function, which means that its args are not unwrapped (they are CustomLoggingTensors), while in def mul they are unwrapped and replaced by their child attributes, so Torch tensors in our case.
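
The difference between the two styles can be reproduced with a small toy model. The sketch below is only an assumption about the general pattern, not the PySyft implementation: toy_module, toy_function and ToyTensor are hypothetical names, and plain floats stand in for torch tensors.

```python
from types import SimpleNamespace


class ToyTensor:
    def __init__(self, child):
        self.child = child


def toy_module(attach):
    """Hypothetical counterpart of a module decorator: collect overloads in a namespace."""
    namespace = SimpleNamespace()
    attach(namespace)
    return namespace


def toy_function(func):
    """Hypothetical counterpart of a function decorator: unwrap args, rewrap the result."""

    def hooked(*args):
        raw = [a.child if isinstance(a, ToyTensor) else a for a in args]
        return ToyTensor(func(*raw))

    return hooked


seen = []  # argument types observed inside each overload


@toy_module
def torch_overloads(module):
    def add(x, y):
        # No function decorator: x and y are still ToyTensors.
        seen.append(type(x).__name__)
        return ToyTensor(x.child + y.child)

    module.add = add

    @toy_function
    def mul(x, y):
        # With the function decorator: x and y are the raw children.
        seen.append(type(x).__name__)
        return x * y

    module.mul = mul

    @toy_module
    def nn(module):
        @toy_function
        def relu(x):
            return max(x, 0.0)

        module.relu = relu

    module.nn = nn


a = ToyTensor(2.0)
print(torch_overloads.add(a, a).child)                 # 4.0
print(torch_overloads.mul(a, a).child)                 # 4.0
print(torch_overloads.nn.relu(ToyTensor(-1.0)).child)  # 0.0
print(seen)                                            # ['ToyTensor', 'float']
```

In this toy, the undecorated add has to unwrap and rewrap by hand, whereas mul and relu can be written as if they received raw values.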

Look at how this changes compared to 2.1: the behaviour is not much different, but the modified functions are now very precisely targeted:

In [21]:
x = th.tensor([1., 2])
x = LoggingTensor().on(x)

# Default overloading made in 2.1
r = th.div(x, x)

# Targeted overloading made in 2.3
r = th.add(x, x)

Default log torch.div
Log function torch.add

Also, compared to 2.1, we changed the function behaviour: for relu, for example, instead of running the built-in relu we run x * (x > 0), even though the output is the same. We could also have called the native relu inside if we wanted, provided that we unwrap the args using @overloaded.function; otherwise we would loop indefinitely.

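@overloaded.module
def functional(module):
    @overloaded.function  # unwraps x, so the native relu runs on the child
    def relu(x):
        print("Log function torch.nn.functional.relu")
        return torch.nn.functional.relu(x)

    module.relu = relu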

2.4 Adding custom tensor attributes

Sometimes you need to add special attributes to your Syft Tensor, like the FixedPrecisionTensor which has a field attribute for example:

In [22]:
fp = th.tensor([1., 2]).fix_precision()
print("Field:", fp.child.field)

(Wrapper)>FixedPrecisionTensor>tensor([1000, 2000])
Field: 4611686018427387904

Just declare them in the __init__, like for example a log_max_size:

In [23]:
class CustomLoggingTensor(AbstractTensor):
    def __init__(self, log_max_size=64, **kwargs):
        super().__init__(**kwargs)
        self.log_max_size = log_max_size
    # [...]

When the chain is rebuilt after an operation, a new CustomLoggingTensor is wrapped on top of the result. To make sure this value is correctly carried over to that new tensor, you should declare get_class_attributes:

In [24]:
class CustomLoggingTensor(AbstractTensor):
    # [...]
    def get_class_attributes(self):
        """
        Return all elements which define an instance of a certain class.
        """
        return {
            'log_max_size': self.log_max_size
        }
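
The following pure-Python sketch shows why such a method is useful when a result is rewrapped. It is a toy under simplifying assumptions: ToyLoggingTensor and its double method are hypothetical, and in PySyft the rewrapping is done by the hook machinery rather than inside each method.

```python
class ToyLoggingTensor:
    def __init__(self, log_max_size=64, child=None):
        self.log_max_size = log_max_size
        self.child = child

    def get_class_attributes(self):
        """Return the attributes that define this instance."""
        return {"log_max_size": self.log_max_size}

    def double(self):
        result = self.child * 2
        # Rebuild a node of the same class around the result, forwarding the
        # instance attributes so they are not reset to their defaults.
        return type(self)(child=result, **self.get_class_attributes())


x = ToyLoggingTensor(log_max_size=8, child=21)
y = x.double()
print(y.child, y.log_max_size)  # 42 8
```

Without the get_class_attributes call, y would be created with the default log_max_size of 64 instead of 8.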

2.5 Sending CustomLoggingTensors

One last thing we love to do is send tensors across workers!

To do so, you need to add a serializer and a deserializer to the class:

In [25]:
# Add these new imports
import syft
from syft.workers.abstract import AbstractWorker

class CustomLoggingTensor(AbstractTensor):
    # [...]
    @staticmethod
    def simplify(tensor: "CustomLoggingTensor") -> tuple:
        """Takes the attributes of a CustomLoggingTensor and saves them in a tuple.

        Args:
            tensor: a CustomLoggingTensor.

        Returns:
            tuple: a tuple holding the unique attributes of the CustomLoggingTensor.
        """
        chain = None
        if hasattr(tensor, "child"):
            chain = syft.serde._simplify(tensor.child)

        # The order must match the unpacking done in detail() below
        return (
            syft.serde._simplify(tensor.id),
            syft.serde._simplify(tensor.log_max_size),
            syft.serde._simplify(tensor.tags),
            syft.serde._simplify(tensor.description),
            chain,
        )

    @staticmethod
    def detail(worker: AbstractWorker, tensor_tuple: tuple) -> "CustomLoggingTensor":
        """
        This function reconstructs a CustomLoggingTensor given its attributes in form of a tuple.

        Args:
            worker: the worker doing the deserialization
            tensor_tuple: a tuple holding the attributes of the CustomLoggingTensor

        Returns:
            CustomLoggingTensor: a CustomLoggingTensor

        Examples:
            shared_tensor = detail(data)
        """
        tensor_id, log_max_size, tags, description, chain = tensor_tuple

        tensor = CustomLoggingTensor(
            owner=worker,
            id=syft.serde._detail(worker, tensor_id),
            log_max_size=syft.serde._detail(worker, log_max_size),
            tags=syft.serde._detail(worker, tags),
            description=syft.serde._detail(worker, description),
        )

        if chain is not None:
            chain = syft.serde._detail(worker, chain)
            tensor.child = chain

        return tensor

Finally, declare this new tensor to the ser/deser module. In serde/:

  • Add an import for CustomLoggingTensor
  • Append CustomLoggingTensor to the OBJ_SIMPLIFIER_AND_DETAILERS list
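
The general simplify/detail pattern with a registry list can be sketched without PySyft as follows. All names here (ToyLoggingTensor, REGISTERED, toy_serialize, toy_deserialize) are hypothetical, and the real serde module additionally handles workers, nested objects and the binary encoding.

```python
class ToyLoggingTensor:
    def __init__(self, id=None, log_max_size=64, child=None):
        self.id = id
        self.log_max_size = log_max_size
        self.child = child

    @staticmethod
    def simplify(tensor):
        """Reduce the tensor to a tuple of simple values."""
        return (tensor.id, tensor.log_max_size, tensor.child)

    @staticmethod
    def detail(tensor_tuple):
        """Rebuild the tensor from the tuple, in the same order as simplify."""
        tensor_id, log_max_size, child = tensor_tuple
        return ToyLoggingTensor(id=tensor_id, log_max_size=log_max_size, child=child)


# Hypothetical registry: the position in the list identifies the type on the wire.
REGISTERED = [ToyLoggingTensor]


def toy_serialize(obj):
    index = REGISTERED.index(type(obj))
    return (index, type(obj).simplify(obj))


def toy_deserialize(payload):
    index, data = payload
    return REGISTERED[index].detail(data)


original = ToyLoggingTensor(id=7, log_max_size=8, child=[1.0, 2.0])
payload = toy_serialize(original)
restored = toy_deserialize(payload)
print(payload)                                             # (0, (7, 8, [1.0, 2.0]))
print(restored.id, restored.log_max_size, restored.child)  # 7 8 [1.0, 2.0]
```

A type that is missing from the registry cannot be serialized in this toy, which mirrors why the new tensor has to be declared to the ser/deser module.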

Everything should now work correctly:

In [26]:
x = th.tensor([1., 2])
x = sy.LoggingTensor().on(x)

p = x.send(alice)
p2 = p + p
x2 = p2.get()

(Wrapper)>[PointerTensor | me:59260307020 -> alice:64823483938]
(Wrapper)>LoggingTensor>tensor([2., 4.])

And there you are: you should now understand all the tools we've built so that you can easily build new tensor types and focus on their behaviour rather than on their integration in the PySyft library.

Congratulations!!! - Time to Join the Community!

Congratulations on completing this notebook tutorial! If you enjoyed this and would like to join the movement toward privacy preserving, decentralized ownership of AI and the AI supply chain (data), you can do so in the following ways!

Star PySyft on GitHub

The easiest way to help our community is just by starring the Repos! This helps raise awareness of the cool tools we're building.

Join our Slack!

The best way to keep up to date on the latest advancements is to join our community! You can do so by filling out the form at

Join a Code Project!

The best way to contribute to our community is to become a code contributor! At any time you can go to PySyft GitHub Issues page and filter for "Projects". This will show you all the top level Tickets giving an overview of what projects you can join! If you don't want to join a project, but you would like to do a bit of coding, you can also look for more "one off" mini-projects by searching for GitHub issues marked "good first issue".

If you don't have time to contribute to our codebase, but would still like to lend support, you can also become a Backer on our Open Collective. All donations go toward our web hosting and other community expenses such as hackathons and meetups!

OpenMined's Open Collective Page
